package mtask

import (
	"sync"
	"sync/atomic"
)

// WorkerGroup tracks a batch of tasks handed to a WorkerChain so that a
// caller can block until all of them have finished.
//
// Every task wrapped by Submit/SubmitWorker sends one completion signal on
// work when it returns (even if it panics, since the send is deferred).
// Await consumes those signals, counting them in current, and compares that
// count against total, the number of tasks registered via AddTotal.
type WorkerGroup struct {
	// work carries one signal per finished task. NewWorkGroup creates it
	// unbuffered, so a finished task's goroutine blocks on the send until
	// Await receives it.
	work    chan int
	// NOTE(review): lock is not referenced anywhere in this file — confirm
	// whether code elsewhere relies on it or whether it is dead.
	lock    *sync.RWMutex
	// current counts completion signals already consumed by Await
	// (incremented through AddCurrent).
	current *atomic.Int32
	// total counts registered tasks (incremented through AddTotal, which
	// Submit and SubmitWorker call before scheduling the task).
	total   *atomic.Int32
	// Worker receives every wrapped task via its Submit method and is
	// stopped by AwaitAndClose.
	Worker  *WorkerChain
}

// NewWorkGroup returns a WorkerGroup that schedules its tasks on workerChain.
//
// Both counters start at zero. The completion channel is unbuffered, so each
// finished task waits for Await to pick up its signal.
func NewWorkGroup(workerChain *WorkerChain) *WorkerGroup {
	group := &WorkerGroup{Worker: workerChain}
	group.work = make(chan int)
	group.current = new(atomic.Int32)
	group.total = new(atomic.Int32)
	group.lock = new(sync.RWMutex)
	return group
}

// Submit registers fn as outstanding work and hands it to the underlying
// WorkerChain.
//
// The task is wrapped so that, once fn returns (or panics), one completion
// signal is sent on the group's work channel for Await to consume. That send
// blocks until Await receives it.
//
// NOTE(review): if WorkerChain.Submit blocks when all of its workers are busy,
// submitting more tasks than it has workers before calling Await could
// deadlock, because busy workers are parked on the completion send. Confirm
// WorkerChain's queuing behavior.
func (ce *WorkerGroup) Submit(fn func(chainId int)) {
	wrapped := func(chainId int) {
		// Deferred so the signal is delivered even if fn panics.
		defer func() { ce.work <- 1 }()
		fn(chainId)
	}
	// Count the task before it can possibly finish, so Await never sees
	// more completions than registrations.
	ce.AddTotal()
	ce.Worker.Submit(wrapped)
}
// SubmitWorker registers fn as outstanding work and schedules fn.Run on the
// underlying WorkerChain. It has the same counting and completion-signal
// behavior as Submit, to which it delegates.
func (ce *WorkerGroup) SubmitWorker(fn Worker) {
	// A closure (rather than the method value fn.Run) defers the interface
	// method lookup until the task runs, as the original wrapping did.
	ce.Submit(func(chainId int) {
		fn.Run(chainId)
	})
}
// Await blocks until every task registered so far has reported completion.
//
// It receives one signal from the work channel per finished task and counts
// it with AddCurrent. It returns as soon as the completed count has caught up
// with the registered count. Because the condition is checked before
// receiving, Await returns immediately when nothing is outstanding: either no
// task was ever submitted, or all previously submitted tasks were already
// awaited. In that case there is no signal to wait for.
//
// The registered total is re-read on every iteration, so tasks submitted
// while Await is running are waited for as long as their AddTotal happens
// before Await's final check.
func (ce *WorkerGroup) Await() {
	// The previous version received first and checked afterwards. With zero
	// outstanding tasks, that receive on the unbuffered channel never
	// completed and Await hung forever.
	for ce.current.Load() < ce.total.Load() {
		<-ce.work
		ce.AddCurrent()
	}
}
// AwaitAndClose waits for all registered tasks (see Await) and then calls
// Stop on the underlying WorkerChain.
func (ce *WorkerGroup) AwaitAndClose() {
	ce.Await()
	chain := ce.Worker
	chain.Stop()
}

// AddCurrent increments the count of completion signals consumed; Await
// calls it once per signal it receives.
func (ce *WorkerGroup) AddCurrent() {
	completed := ce.current
	completed.Add(1)
}
// AddTotal increments the count of registered tasks; Submit and
// SubmitWorker call it before scheduling each task.
func (ce *WorkerGroup) AddTotal() {
	registered := ce.total
	registered.Add(1)
}
